# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default data formatting functions for experiments.

For new datasets, inherit from GenericDataFormatter and implement
all abstract functions.

These dataset-specific methods:
1) Define the column and input types for tabular dataframes used by model
2) Perform the necessary input feature engineering & normalisation steps
3) Revert the normalisation for predictions
4) Are responsible for train, validation and test splits


"""

import abc
import enum


# Type definitions
class DataTypes(enum.IntEnum):
    """Numerical kind of a column.

    Being an IntEnum, each member also compares equal to its plain int value
    (REAL_VALUED == 0, CATEGORICAL == 1, DATE == 2).
    """

    REAL_VALUED = 0  # continuous values; grouped with the real-valued inputs
    CATEGORICAL = 1  # discrete labels; grouped with the categorical inputs
    DATE = 2  # date-typed column; neither real-valued nor categorical


class InputTypes(enum.IntEnum):
    """Role that each column plays for the model.

    Values are consecutive ints starting at 0, in declaration order. The
    formatter's column-definition check requires exactly one ID column and
    exactly one TIME column.
    """

    TARGET = 0  # value(s) the model predicts
    OBSERVED_INPUT = 1  # presumably only observed up to the present -- confirm
    KNOWN_INPUT = 2  # presumably known in advance for future steps -- confirm
    STATIC_INPUT = 3  # presumably constant per entity -- confirm
    ID = 4  # the single column identifying an entity
    TIME = 5  # the single column serving purely as the time index


class GenericDataFormatter(abc.ABC):
    """Abstract base class for all data formatters.

    Users implement the abstract methods below to perform dataset-specific
    manipulations (scaling, feature transformation, splitting) and to declare
    the column definition. The concrete helpers in this class derive the
    column ordering and input indices from that definition.

    Column definitions are sequences of tuples whose elements are used as
    (column name, DataTypes member, InputTypes member).
    """

    @abc.abstractmethod
    def set_scalers(self, df):
        """Calibrates scalers using the data supplied."""
        raise NotImplementedError()

    @abc.abstractmethod
    def transform_inputs(self, df):
        """Performs feature transformation."""
        raise NotImplementedError()

    @abc.abstractmethod
    def format_predictions(self, df):
        """Reverts any normalisation to give predictions in original scale."""
        raise NotImplementedError()

    @abc.abstractmethod
    def split_data(self, df):
        """Performs the default train, validation and test splits."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def _column_definition(self):
        """Defines order, input type and data type of each column.

        Must be re-iterable (e.g. a list): it is scanned several times.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_fixed_params(self):
        """Defines the fixed parameters used by the model for training.

        Requires the following keys:
          'total_time_steps': Defines the total number of time steps used by TFT
          'num_encoder_steps': Determines length of LSTM encoder (i.e. history)
          'num_epochs': Maximum number of epochs for training
          'early_stopping_patience': Early stopping param for keras
          'multiprocessing_workers': # of cpus for data processing


        Returns:
          A dictionary of fixed parameters, e.g.:

          fixed_params = {
              'total_time_steps': 252 + 5,
              'num_encoder_steps': 252,
              'num_epochs': 100,
              'early_stopping_patience': 5,
              'multiprocessing_workers': 5,
          }
        """
        raise NotImplementedError()

    # Shared functions across data-formatters
    @property
    def num_classes_per_cat_input(self):
        """Returns number of categories per relevant input.

        This is subsequently required for keras embedding layers.

        The base class never assigns `_num_classes_per_cat_input`; a concrete
        subclass must set it (presumably in `set_scalers`) before this property
        is read, otherwise an AttributeError is raised.
        """
        return self._num_classes_per_cat_input

    def get_num_samples_for_calibration(self):
        """Gets the default number of training and validation samples.

        Use to sub-sample the data for network calibration and a value of -1 uses
        all available samples.

        Returns:
          Tuple of (training samples, validation samples)
        """
        return -1, -1

    @staticmethod
    def _is_model_input(tup):
        """Returns True unless the column tuple is tagged as the ID or TIME."""
        return tup[2] not in {InputTypes.ID, InputTypes.TIME}

    def get_column_definition(self):
        """Returns formatted column definition in order expected by the TFT.

        The result is ordered as: the ID column, the TIME column, all
        real-valued inputs, then all categorical inputs. Each group keeps its
        relative order from `_column_definition`. Any other non-ID/TIME column
        (e.g. of DataTypes.DATE) is left out.

        Raises:
          ValueError: if there is not exactly one ID column and exactly one
            TIME column.
        """
        column_definition = self._column_definition

        def _with_input_type(input_type):
            return [tup for tup in column_definition if tup[2] == input_type]

        identifier = _with_input_type(InputTypes.ID)
        time = _with_input_type(InputTypes.TIME)

        # Sanity checks first: exactly one ID and one time column must exist.
        for input_type, matches in ((InputTypes.ID, identifier),
                                    (InputTypes.TIME, time)):
            if len(matches) != 1:
                raise ValueError('Illegal number of inputs ({}) of type {}'.format(
                    len(matches), input_type))

        real_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.REAL_VALUED and self._is_model_input(tup)
        ]
        categorical_inputs = [
            tup for tup in column_definition
            if tup[1] == DataTypes.CATEGORICAL and self._is_model_input(tup)
        ]

        return identifier + time + real_inputs + categorical_inputs

    def _get_input_columns(self):
        """Returns names of all input columns (everything except ID and TIME)."""
        return [
            tup[0]
            for tup in self.get_column_definition()
            if self._is_model_input(tup)
        ]

    def _get_tft_input_indices(self):
        """Returns the relevant indexes and input sizes required by TFT.

        'output_size', 'input_obs_loc' and 'static_input_loc' are positions
        within the model-input columns of `get_column_definition()` (ID and
        TIME removed). 'known_regular_inputs' and 'known_categorical_inputs'
        are positions within the real-valued and the categorical input lists
        respectively.
        """

        def _get_locations(input_types, defn):
            return [i for i, tup in enumerate(defn) if tup[2] in input_types]

        # Model inputs only, in the order produced by get_column_definition().
        column_definition = [
            tup for tup in self.get_column_definition()
            if self._is_model_input(tup)
        ]

        categorical_inputs = [
            tup for tup in column_definition if tup[1] == DataTypes.CATEGORICAL
        ]
        real_inputs = [
            tup for tup in column_definition if tup[1] == DataTypes.REAL_VALUED
        ]

        target_locations = _get_locations({InputTypes.TARGET}, column_definition)
        known_types = {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}

        return {
            'input_size': len(column_definition),
            'output_size': len(target_locations),
            'category_counts': self.num_classes_per_cat_input,
            'input_obs_loc': target_locations,
            'static_input_loc':
                _get_locations({InputTypes.STATIC_INPUT}, column_definition),
            'known_regular_inputs': _get_locations(known_types, real_inputs),
            'known_categorical_inputs':
                _get_locations(known_types, categorical_inputs),
        }

    def get_experiment_params(self):
        """Returns fixed model parameters for experiments.

        Combines `get_fixed_params()` with the formatted column definition and
        the TFT input indices. The returned dict is a new object; the mapping
        returned by `get_fixed_params()` is not modified.

        Raises:
          ValueError: if any required fixed parameter is missing.
        """

        required_keys = [
            'total_time_steps', 'num_encoder_steps', 'num_epochs',
            'early_stopping_patience', 'multiprocessing_workers'
        ]

        # Copy so the derived keys added below never leak into a dict that a
        # subclass may keep and return from every get_fixed_params() call.
        fixed_params = dict(self.get_fixed_params())

        # Report every missing key at once rather than only the first.
        missing = [k for k in required_keys if k not in fixed_params]
        if missing:
            raise ValueError('Field {}'.format(', '.join(missing)) +
                             ' missing from fixed parameter definitions!')

        fixed_params['column_definition'] = self.get_column_definition()

        fixed_params.update(self._get_tft_input_indices())

        return fixed_params
